#include "stdafx.h"
#include "ColorSpaceRead.h"

// private function
// Convert one bit-plane of a color-space read into the matching bit-plane in
// base space. Bit 0 already holds the first base and is copied unchanged; for
// each later position, base[p + 1] = base[p] XOR color[p + 1] on this plane.
// Positions are processed for CReadInBits::iReadLength iterations.
WORD_SIZE colors2Bases(WORD_SIZE readInColors)
{
    WORD_SIZE bases = readInColors & 0x01;          // first base, copied as-is
    const WORD_SIZE colors = readInColors >> 0x01;  // color for position p + 1 now sits at bit p
    WORD_SIZE bitAtPos = 0x01;
    int pos = 0;
    while (pos < CReadInBits::iReadLength) {
        // Derive the next base bit from the current base bit and the next color bit.
        const WORD_SIZE nextBaseBit = (bases ^ colors) & bitAtPos;
        bases |= (nextBaseBit << 0x01); // bit pos + 1 is still clear, so |= equals +=
        bitAtPos <<= 0x01;
        ++pos;
    }
    return bases;
}

// Convert a whole color-space read (both bit-planes) into base space.
// Assumes the first base is already encoded in bit 0 of readInColors.
CReadInBits colors2Bases(CReadInBits readInColors)
{
    CReadInBits bases;
    bases.LowerBits = colors2Bases(readInColors.LowerBits);
    bases.UpperBits = colors2Bases(readInColors.UpperBits);
    return bases;
}
// Classify how a color-space read differs from its reference over the first
// CReadInBits::iReadLength positions.
// Returns:
//   -1  complement SNP (A<->T, C<->G): two adjacent positions where both planes differ;
//   -2  transversion SNP (A<->C, G<->T): two adjacent positions where only the lower plane differs;
//   -3  transition SNP (A<->G, C<->T): two adjacent positions where only the upper plane differs;
//   otherwise the number of mismatched color positions.
// The SNP patterns are only looked for when there are 2 or 3 mismatches, and
// are tested in the order -1, -2, -3.
int getSNPtype(CReadInBits readInColors, CReadInBits refInColors)
{
    // compare only the first readlength bits
    const unsigned int readLength = (unsigned int)CReadInBits::iReadLength;
    readInColors = readInColors.getPrefixStr(readLength);
    refInColors = refInColors.getPrefixStr(readLength);

    const WORD_SIZE upperDiff = readInColors.UpperBits ^ refInColors.UpperBits;
    const WORD_SIZE lowerDiff = readInColors.LowerBits ^ refInColors.LowerBits;
    WORD_SIZE mismatchBits = upperDiff | lowerDiff;

    // Count mismatched positions (population count of mismatchBits).
    int mismatches;
#ifdef __GNUC__
    mismatches = __builtin_popcountll(mismatchBits);
#else
    mismatches = 0;
    while (mismatchBits != 0) {
        mismatchBits &= mismatchBits - 1; // drop the lowest set bit
        ++mismatches;
    }
#endif

    if (mismatches < 2 || mismatches > 3) {
        return mismatches;
    }

    // Per-position change patterns, in priority order (index k -> return -(k + 1)).
    const WORD_SIZE patterns[3] = {
        upperDiff & lowerDiff,   // complement: R <-> B or Y <-> G
        ~upperDiff & lowerDiff,  // transversion: R <-> Y or B <-> G
        upperDiff & ~lowerDiff   // transition: B <-> Y or R <-> G
    };
    for (int k = 0; k < 3; ++k) {
        // A set bit here marks two consecutive positions with the same change.
        if (patterns[k] & (patterns[k] >> 1)) {
            return -(k + 1);
        }
    }
    return mismatches;
}

bool encodeColorsNas3(const char* colorsStr, CReadInBits& readInColors)
{
    const WORD_SIZE bit = 0x01;
    readInColors.LowerBits = 0;
    readInColors.UpperBits = 0;
    setFirstBase(base2Color(colorsStr[0], colorsStr[1]), readInColors);
    for (int i = 2; ; i++) {
        switch (colorsStr[i]) {
        case '0':
            break;
        case '1':
            readInColors.LowerBits += (bit << (i - 1));
            break;
        case '2':
            readInColors.UpperBits += (bit << (i - 1));
            break;
        case '3':
            readInColors.UpperBits += (bit << (i - 1));
            readInColors.LowerBits += (bit << (i - 1));
            break;
        case 'N':
        case '.': // encode unknown as 3
            printf("Warning! Encode '.' in %s as color '3'\n", colorsStr);
            readInColors.UpperBits += (bit << (i - 1));
            readInColors.LowerBits += (bit << (i - 1));
            break;
        case '\0':
            return (true);
        default:
            return (false);
        }
    }
}

bool encodeColors(const char* colorsStr, CReadInBits& readInColors)
{
    const WORD_SIZE bit = 0x01;
    readInColors.LowerBits = 0;
    readInColors.UpperBits = 0;
    setFirstBase(base2Color(colorsStr[0], colorsStr[1]), readInColors);
    if (colorsStr[1] == '.') {
        return(false);
    }
    for (int i = 2; ; i++) {
        switch (colorsStr[i]) {
        case '0':
            break;
        case '1':
            readInColors.LowerBits += (bit << (i - 1));
            break;
        case '2':
            readInColors.UpperBits += (bit << (i - 1));
            break;
        case '3':
            readInColors.UpperBits += (bit << (i - 1));
            readInColors.LowerBits += (bit << (i - 1));
            break;
        case '\0':
            return (true);
        default:
            return (false);
        }
    }
}

char* decodeColors(char* colorsStr, CReadInBits readInColors)
{
    int i;
    for (i = 0; i < CReadInBits::iReadLength; i++) {
        WORD_SIZE c = (readInColors.UpperBits & 0x01) << 1 | (readInColors.LowerBits & 0x01);
        if (i == 0) {
            switch (c) {
            case 0x00:
                colorsStr[0] = 'A';
                break;
            case 0x01:
                colorsStr[0] = 'C';
                break;
            case 0x02:
                colorsStr[0] = 'G';
                break;
            case 0x03:
                colorsStr[0] = 'T';
                break;
            default:
                colorsStr[0] = 'N';
            }
        } else {
            colorsStr[i] = '0' + (char)(c);
        }
        readInColors.LowerBits >>= 1;
        readInColors.UpperBits >>= 1;
    }
    colorsStr[i] = '\0';
    return(colorsStr);
}

// Correct isolated single-color mismatches of a color-space read against its
// reference. Read colors are kept only where they form a valid SNP signature:
// two adjacent color changes of the same type.
// Parameters:
//   readInColors  - read in color space; only the first iReadLength positions are compared.
//   refInColors   - reference in color space; truncated the same way.
//   correctedRead - output. It starts as the (truncated) reference colors. Read
//                   colors are then substituted by the mask-taking
//                   correctReadInColorSpace overload (defined elsewhere).
//                   NOTE(review): presumably that overload takes read colors at
//                   the mask's set bits and reference colors elsewhere; confirm.
// Returns:
//   - no two adjacent mismatches: the number of mismatched color positions;
//   - three or more adjacent mismatches anywhere: -1 (correctedRead = reference);
//   - otherwise: diff - 2 * (base-space mismatches after correction)
//                + (1 if the last color position mismatched).
int correctReadInColorSpace(CReadInBits readInColors, CReadInBits refInColors, CReadInBits& correctedRead)
{
    // printBitsStr(readInColors, CReadInBits::iReadLength);// DEBUG
    // printBitsStr(refInColors, CReadInBits::iReadLength); // DEBUG
    // compare only the first read-length bits
    readInColors = readInColors.getPrefixStr((unsigned int)CReadInBits::iReadLength);
    refInColors  = refInColors.getPrefixStr((unsigned int)CReadInBits::iReadLength);
    correctedRead = refInColors; // Default is the same

    WORD_SIZE upperBitsDiff = readInColors.UpperBits ^ refInColors.UpperBits;
    WORD_SIZE lowerBitsDiff = readInColors.LowerBits ^ refInColors.LowerBits;
    WORD_SIZE d = upperBitsDiff | lowerBitsDiff; //bits string indicating which bits are different
    int diff;
    // Count mismatched positions. Note: the builtin is used only when AMD is
    // defined here, whereas getSNPtype keys the same choice on __GNUC__.
#ifdef AMD
    diff = __builtin_popcountll(d); // magic function to calculate how many ones are there
#else
    for (diff = 0; d; diff++) {
        d &= d - 1; // clear the least significant bit set
    }
#endif
    // Mask of the last compared position (bit iReadLength - 1).
    WORD_SIZE lastBit = SHIFT_LEFT(0x01, (CReadInBits::iReadLength - 1));
    WORD_SIZE diffStr = upperBitsDiff | lowerBitsDiff;
    // A set bit at p means positions p and p + 1 both mismatch.
    WORD_SIZE snpBits = diffStr & (diffStr >> 1);
    if (snpBits == 0) { // no consecutive mismatches
        // Isolated mismatches stay corrected to the reference. The exception is
        // a lone mismatch at the last position, which is taken from the read,
        // presumably because no following color exists to confirm it as an error.
        if (diffStr == lastBit) {
            correctedRead = correctReadInColorSpace(readInColors, refInColors, lastBit);
        }
        return(diff);
    } else if (snpBits & (snpBits >> 1)) {
        // Positions p, p + 1 and p + 2 all mismatch: not a single-SNP pattern.
        correctedRead = refInColors;
        return(-1); // Three consecutive mismatches
    } else {
        // Collect the adjacent mismatch pairs that match one of the three SNP
        // types; only those positions are taken from the read.
        // (The inner diffStr/snpBits intentionally shadow the outer ones.)
        WORD_SIZE replacedBits = 0x00;
        { // (1) Complement Type of SNP, (A <-> T, C<->G)
            WORD_SIZE diffStr = upperBitsDiff & lowerBitsDiff;
            WORD_SIZE snpBits = diffStr & (diffStr >> 1);
            replacedBits |= ( snpBits | (snpBits << 1) );
        } // Indicated by two consecutive of color changed R <-> B or Y <-> G
        { // (2) Transversion Type of SNP, (A <-> C, G <-> T)
            WORD_SIZE diffStr = ~upperBitsDiff & lowerBitsDiff;
            WORD_SIZE snpBits = diffStr & (diffStr >> 1);
            replacedBits |= ( snpBits | (snpBits << 1) );
        } // Indicated by two consecutive of color changed R <-> Y or B <-> G
        { // (3) Transition Type of SNP, (A <-> G, C <-> T)
            WORD_SIZE diffStr = upperBitsDiff & ~lowerBitsDiff;
            WORD_SIZE snpBits = diffStr & (diffStr >> 1);
            replacedBits |= ( snpBits | (snpBits << 1) );
        } // Indicated by two consecutive of color changed B <-> Y or R <-> G
        // If the last position mismatches but the one before it does not
        // (an isolated mismatch at the end), also take the last color from the read.

        bool lastBitDiff = ((diffStr & lastBit) != 0);
        if (lastBitDiff && !(diffStr & ( lastBit >> 0x01 ) ) ) {
            // no consecutive mismatches in the end
            replacedBits |= lastBit; // take the last bit from read
        }
        correctedRead = correctReadInColorSpace(readInColors, refInColors, replacedBits);

        // ASSERT
        // Count base-space differences between the corrected read and the
        // reference. bitsStrNCompare is defined elsewhere and presumably returns
        // the number of differing positions among the first iReadLength; confirm.
        CReadInBits read = colors2Bases(correctedRead);
        // printBitsStr(read, (unsigned int)CReadInBits::iReadLength);
        CReadInBits ref = colors2Bases(refInColors);
        // printBitsStr(ref, (unsigned int)CReadInBits::iReadLength);
        unsigned int NoSNP = bitsStrNCompare(read, ref, (unsigned int)CReadInBits::iReadLength);
        // assert(NoSNP <= 5);
        // assertSNP(type, refInColors, correctedRead);
        // Each interior base change accounts for two color mismatches, hence * 2.
        // A change at the last base presumably shows only one, hence the + lastBitDiff.
        return(diff - (int)NoSNP * 2 + (int)lastBitDiff);
    }
}

// Decode a color-space read into bases, written to caRead (which is also returned).
// When `correct` is true:
//   - the read is first corrected against `ref` with correctReadInColorSpace;
//   - caQscore is converted in place from color to base qualities;
//   - the corrected colors are decoded to bases.
// Otherwise the raw read is decoded by decodeColorReadWithPrimer (defined
// elsewhere), and caQscore is left untouched.
char* correctAndDecodeRead(CReadInBits read, CReadInBits ref, bool correct, char* caRead, char* caQscore)
{
    if (correct) {
        CReadInBits correctedRead;
        // The mismatch count returned by correctReadInColorSpace is not needed
        // here: negative (unresolvable) and non-negative results were handled
        // identically, so that duplicated branch has been removed.
        // TODO: when the result is negative, use DP to get the bases inferred by
        // the color string instead of falling back to the reference.
        correctReadInColorSpace(read, ref, correctedRead);
        colorQV2baseQV(read, correctedRead, caQscore);
        colors2Bases(correctedRead).decode(caRead);
    } else {
        decodeColorReadWithPrimer(caRead, read); // decodeColors(caRead, reads);
    }
    return(caRead);
}

// Given a read in bases, return its color-space signal spelled with letters,
// where A=0, C=1, G=2 and T=3.
// NOTE(review): the decoded text must fit in MAX_READ_LENGTH chars; confirm
// that callers never pass longer reads.
string readInBases2ColorsInACGT_Format(string readInBases)
{
    CReadInBits basesInBits(readInBases.c_str());
    char colorsAsLetters[MAX_READ_LENGTH];
    bases2Colors(basesInBits).decode(colorsAsLetters);
    return string(colorsAsLetters);
}

// Given a color read spelled in ACGT letters, return it in 0123 digit format.
// The output is the leading base letter followed by color digits. Position 0
// (the base itself) is emitted as color '0'.
string colorReadInACGTto0123Format(string colorReadInACGT)
{
    char formatted[MAX_READ_LENGTH];
    CReadInBits colorBits(colorReadInACGT.c_str());
    // Decode one slot to the right, leaving slot 0 free for the base letter.
    char* const decoded = formatted + 1;
    decodeColors(decoded, colorBits);
    // Move the base letter to the front and put color '0' where it was.
    formatted[0] = decoded[0];
    decoded[0] = '0';
    return string(formatted);
}

// TEST
// Sanity check that a corrected read (crInColors) differs from the reference
// (refInColors) in the way SNPType claims:
//   0     - colors must match exactly;
//   1/2/3 - base space must show at least one complement / transversion /
//           transition change;
//   other - bitsStrCompare (defined elsewhere) must report 2, presumably two
//           differing bases.
void assertSNP(int SNPType, CReadInBits refInColors, CReadInBits crInColors)
{
    if (SNPType == 0) {
        ASSERT_TRUE(refInColors == crInColors, "MISMATCHES after correction");
        return;
    }

    // Compare in base space, restricted to the read length.
    CReadInBits ref = colors2Bases(refInColors);
    CReadInBits cr = colors2Bases(crInColors);
    ref = ref.getPrefixStr(CReadInBits::iReadLength);
    cr = cr.getPrefixStr(CReadInBits::iReadLength);
    // Per-plane difference masks between corrected read and reference.
    WORD_SIZE u = (cr.UpperBits ^ ref.UpperBits);
    WORD_SIZE l = (cr.LowerBits ^ ref.LowerBits);

    switch (SNPType) {
    case 1:
        ASSERT_TRUE((u & l) > 0,  "Not real complement SNP");
        break;
    case 2:
        ASSERT_TRUE((~u & l) > 0, "Not real transverstion SNP");
        break;
    case 3:
        ASSERT_TRUE((u & ~l) > 0, "Not real transition SNP");
        break;
    default:
        ASSERT_TRUE(2 == bitsStrCompare(ref, cr), "Not double SNPs");
        break;
    }
}

// Convert per-color quality characters in Qscores to per-base qualities, in place.
// Colors where the read differs from its corrected version (getDiffBits, defined
// elsewhere) are treated as color errors. A leading '0' in Qscores is taken to
// mean quality scores are unavailable, and Qscores is then left untouched.
void colorQV2baseQV(CReadInBits readInColors, CReadInBits& correctedRead, char* Qscores)
{
    if (Qscores[0] == '0') {
        return; // quality scores are not available
    }
    // TODO the read length is fixed under 64 now.
    const unsigned int readLength = (unsigned int)CReadInBits::iReadLength;
    colorQV2baseQV(getDiffBits(readInColors, correctedRead), Qscores, readLength);
}

// Helper for colorQV2baseQV: remove the Phred offset (33) from one color quality
// character and apply the error-color rule.
// For a color flagged as an error, a positive quality is negated. A negative
// quality on such a color is clamped to 0 and reported through negativeQ.
// Arithmetic is done in int, so the result does not depend on whether plain
// char is signed.
static int adjustColorQV(char rawQV, bool errorColorFlag, bool& negativeQ)
{
    int q = (int)(unsigned char)rawQV - 33;
    if (errorColorFlag) {
        if (q > 0) {
            q = -q;
        } else if (q < 0) {
            q = 0;
            negativeQ = true;
        }
    }
    return q;
}

// Convert per-color quality characters (Phred+33) into per-base quality
// characters (Phred+33), in place, for the first readLength entries of Qscores.
// Bit k of singleColorErrorflag marks color k as an error color; see
// adjustColorQV for how its quality is adjusted.
// Base quality k = colorQ[k] + colorQ[k + 1]. The last base has no following
// color, so it uses colorQ[readLength - 1] alone: nothing past readLength is
// read. The result is clamped to the printable Phred range [0, 93].
// Returns true if any flagged color had a quality below the Phred offset.
bool colorQV2baseQV(WORD_SIZE singleColorErrorflag, char* Qscores, unsigned int readLength)
{
    const int PhedScoreShift = 33;
    const int maxBaseQV = '~' - PhedScoreShift; // highest printable Phred+33 value
    bool negativeQ = false;
    if (!Qscores || readLength == 0) {
        return(negativeQ);
    }

    // Single pass: colorQ[i + 1] is read before Qscores[i] is overwritten, and is
    // carried forward, so every color quality is consumed before it is replaced.
    int thisColorQV = adjustColorQV(Qscores[0], (singleColorErrorflag & 0x01) != 0, negativeQ);
    for (unsigned int i = 0; i < readLength; i++) {
        singleColorErrorflag >>= 0x01; // bit 0 now describes color i + 1
        int nextColorQV = 0;
        if (i + 1 < readLength) {
            nextColorQV = adjustColorQV(Qscores[i + 1], (singleColorErrorflag & 0x01) != 0, negativeQ);
        }
        int baseQV = thisColorQV + nextColorQV;
        if (baseQV < 0) {
            baseQV = 0;
        } else if (baseQV > maxBaseQV) {
            baseQV = maxBaseQV; // keep the result printable and within char range
        }
        Qscores[i] = (char)(baseQV + PhedScoreShift);
        thisColorQV = nextColorQV;
    }
    return(negativeQ);
}

